-
Notifications
You must be signed in to change notification settings - Fork 15k
[mlir][linalg] Fix Linalg runtime verification pass to handle tensors with dimensions of size 0 #163791
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][linalg] Fix Linalg runtime verification pass to handle tensors with dimensions of size 0 #163791
Conversation
|
@llvm/pr-subscribers-mlir-linalg Author: Hanumanth (Hanumanth04) Changes[mlir][linalg] Fix Linalg runtime verification pass to handle tensors with dimensions of size 0 Runtime verification on Linalg structured ops unconditionally computed The issue occurs because:
The fix is to guard all runtime verification with a check that ensures all loop ranges are non-empty (start < end) before performing any index arithmetic. Example MLIR that previously failed: func.func @<!-- -->fill_empty() -> tensor<0xi32> {
%c0 = arith.constant 0 : i32
%empty = tensor.empty() : tensor<0xi32>
%filled = linalg.fill ins(%c0 : i32) outs(%empty : tensor<0xi32>) -> tensor<0xi32>
return %filled : tensor<0xi32>
}Full diff: https://github.com/llvm/llvm-project/pull/163791.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index 15eb51a6dcab2..737652c8cb9d1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
@@ -43,6 +44,32 @@ struct StructuredOpInterface
auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
auto one = arith::ConstantIndexOp::create(builder, loc, 1);
+ Value iterationDomainIsNonDegenerate;
+ for (auto [start, end] : llvm::zip(starts, ends)) {
+ auto startValue = getValueOrCreateConstantIndexOp(builder, loc, start);
+ auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
+
+ // Loop Trip count > 0 iff start < end
+ Value dimensionHasNonZeroTripCount = builder.create<index::CmpOp>(
+ loc, index::IndexCmpPredicate::SLT, startValue, endValue);
+
+ if (!iterationDomainIsNonDegenerate) {
+ iterationDomainIsNonDegenerate = dimensionHasNonZeroTripCount;
+ } else {
+ // Iteration domain is non-degenerate iff all dimensions have loop trip
+ // count > 0
+ iterationDomainIsNonDegenerate = builder.create<arith::AndIOp>(
+ loc, iterationDomainIsNonDegenerate, dimensionHasNonZeroTripCount);
+ }
+ }
+
+ if (!iterationDomainIsNonDegenerate)
+ return;
+
+ auto ifOp = builder.create<scf::IfOp>(loc, iterationDomainIsNonDegenerate,
+ /*withElseRegion=*/false);
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
// Subtract one from the loop ends before composing with the indexing map
transform(ends, ends.begin(), [&](OpFoldResult end) {
auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
@@ -110,11 +137,11 @@ struct StructuredOpInterface
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
}
}
+ builder.setInsertionPointAfter(ifOp);
}
};
-template <typename... OpTs>
-void attachInterface(MLIRContext *ctx) {
+template <typename... OpTs> void attachInterface(MLIRContext *ctx) {
(OpTs::template attachInterface<StructuredOpInterface<OpTs>>(*ctx), ...);
}
} // namespace
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
index 9f4393efc87bf..e48dcbd6c6110 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
@@ -103,6 +103,11 @@ func.func @main() {
// CHECK: unexpected negative result on dimension #0 of input/output operand #0
func.call @reverse_from_3(%d5x) : (tensor<?xf32>) -> (tensor<?xf32>)
+ %c0x = arith.constant dense<1.0> : tensor<0xf32>
+ %d0x = tensor.cast %c0x : tensor<0xf32> to tensor<?xf32>
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @fill_empty_1d(%d0x) : (tensor<?xf32>) -> (tensor<?xf32>)
+
return
}
@@ -297,3 +302,9 @@ func.func @reverse_from_3(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
} -> tensor<?xf32>
return %result : tensor<?xf32>
}
+
+func.func @fill_empty_1d(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+ %c0 = arith.constant 0.0 : f32
+ %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?xf32>) -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
|
|
@llvm/pr-subscribers-mlir Author: Hanumanth (Hanumanth04) Changes[mlir][linalg] Fix Linalg runtime verification pass to handle tensors with dimensions of size 0 Runtime verification on Linalg structured ops unconditionally computed The issue occurs because:
The fix is to guard all runtime verification with a check that ensures all loop ranges are non-empty (start < end) before performing any index arithmetic. Example MLIR that previously failed: func.func @<!-- -->fill_empty() -> tensor<0xi32> {
%c0 = arith.constant 0 : i32
%empty = tensor.empty() : tensor<0xi32>
%filled = linalg.fill ins(%c0 : i32) outs(%empty : tensor<0xi32>) -> tensor<0xi32>
return %filled : tensor<0xi32>
}Full diff: https://github.com/llvm/llvm-project/pull/163791.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index 15eb51a6dcab2..737652c8cb9d1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
@@ -43,6 +44,32 @@ struct StructuredOpInterface
auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
auto one = arith::ConstantIndexOp::create(builder, loc, 1);
+ Value iterationDomainIsNonDegenerate;
+ for (auto [start, end] : llvm::zip(starts, ends)) {
+ auto startValue = getValueOrCreateConstantIndexOp(builder, loc, start);
+ auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
+
+ // Loop Trip count > 0 iff start < end
+ Value dimensionHasNonZeroTripCount = builder.create<index::CmpOp>(
+ loc, index::IndexCmpPredicate::SLT, startValue, endValue);
+
+ if (!iterationDomainIsNonDegenerate) {
+ iterationDomainIsNonDegenerate = dimensionHasNonZeroTripCount;
+ } else {
+ // Iteration domain is non-degenerate iff all dimensions have loop trip
+ // count > 0
+ iterationDomainIsNonDegenerate = builder.create<arith::AndIOp>(
+ loc, iterationDomainIsNonDegenerate, dimensionHasNonZeroTripCount);
+ }
+ }
+
+ if (!iterationDomainIsNonDegenerate)
+ return;
+
+ auto ifOp = builder.create<scf::IfOp>(loc, iterationDomainIsNonDegenerate,
+ /*withElseRegion=*/false);
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
// Subtract one from the loop ends before composing with the indexing map
transform(ends, ends.begin(), [&](OpFoldResult end) {
auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
@@ -110,11 +137,11 @@ struct StructuredOpInterface
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
}
}
+ builder.setInsertionPointAfter(ifOp);
}
};
-template <typename... OpTs>
-void attachInterface(MLIRContext *ctx) {
+template <typename... OpTs> void attachInterface(MLIRContext *ctx) {
(OpTs::template attachInterface<StructuredOpInterface<OpTs>>(*ctx), ...);
}
} // namespace
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
index 9f4393efc87bf..e48dcbd6c6110 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
@@ -103,6 +103,11 @@ func.func @main() {
// CHECK: unexpected negative result on dimension #0 of input/output operand #0
func.call @reverse_from_3(%d5x) : (tensor<?xf32>) -> (tensor<?xf32>)
+ %c0x = arith.constant dense<1.0> : tensor<0xf32>
+ %d0x = tensor.cast %c0x : tensor<0xf32> to tensor<?xf32>
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @fill_empty_1d(%d0x) : (tensor<?xf32>) -> (tensor<?xf32>)
+
return
}
@@ -297,3 +302,9 @@ func.func @reverse_from_3(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
} -> tensor<?xf32>
return %result : tensor<?xf32>
}
+
+func.func @fill_empty_1d(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+ %c0 = arith.constant 0.0 : f32
+ %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?xf32>) -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
|
Tagging @ryanpholt, could you please take a look? |
7f694ee to
f658bd3
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense to me -- thanks!
f658bd3 to
43f73d7
Compare
|
Hi @ryanpholt , @matthias-springer I don't have permission to merge the change. Could you please help merge this PR when you get a chance? Thanks! |
Runtime verification on Linalg structured ops unconditionally computed
end - 1to determine the last iteration index before composing indexing maps. This caused spurious "negative index" assertion failures while operating on empty tensors (tensors with a dimension of size 0).The issue occurs because:
Empty tensors create loop ranges [0, 0) with zero trip count
Computing end - 1 = 0 - 1 = -1 creates a fictitious negative index
The negative index check triggers even though no loop iterations occur
The fix is to guard all runtime verification with a check that ensures all loop ranges are non-empty (start < end) before performing any index arithmetic.
Example MLIR that previously failed: